{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "t09eeeR5prIJ"
   },
   "source": [
    "# TensorFlow 2.0 Tutorial: Text Generation with an RNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BwpJ5IffzRG6"
   },
   "source": [
    "This tutorial demonstrates how to generate text with a character-based RNN. We will use the Shakespeare dataset from Andrej Karpathy's \"The Unreasonable Effectiveness of Recurrent Neural Networks\". Given a sequence of characters from this data (\"Shakespear\"), the model is trained to predict the next character in the sequence (\"e\"). Longer sequences of text can be generated by calling the model repeatedly.\n",
    "\n",
    "This tutorial includes runnable code implemented with tf.keras and eager execution. The following is sample output from the model in this tutorial after training for 30 epochs, started with the string \"Q\":\n",
    "\n",
    "<pre>\n",
    "QUEENE:\n",
    "I had thought thou hadst a Roman; for the oracle,\n",
    "Thus by All bids the man against the word,\n",
    "Which are so weak of care, by old care done;\n",
    "Your children were in your holy love,\n",
    "And the precipitation through the bleeding throne.\n",
    "\n",
    "BISHOP OF ELY:\n",
    "Marry, and will, my lord, to weep in such a one were prettiest;\n",
    "Yet now I was adopted heir\n",
    "Of the world's lamentable day,\n",
    "To watch the next way with his father with his face?\n",
    "\n",
    "ESCALUS:\n",
    "The cause why then we are all resolved more sons.\n",
    "\n",
    "VOLUMNIA:\n",
    "O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,\n",
    "And love and pale as any will to that word.\n",
    "\n",
    "QUEEN ELIZABETH:\n",
    "But how long have I heard the soul for this world,\n",
    "And show his hands of life be proved to stand.\n",
    "\n",
    "PETRUCHIO:\n",
    "I say he look'd on, if I must be content\n",
    "To stay him from the fatal of our country's bliss.\n",
    "His lordship pluck'd from this sentence then for prey,\n",
    "And then let us twain, being the moon,\n",
    "were she such a case as fills m\n",
    "</pre>\n",
    "\n",
    "While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:\n",
    "\n",
    "- The model is character-based. When training started, it did not know how to spell an English word, or even that words are a unit of text.\n",
    "- The structure of the output resembles a play: blocks of text generally begin with a speaker name in all capital letters, as in the dataset.\n",
    "- As shown below, the model is trained on small batches of text (100 characters each) and is still able to generate a longer sequence of text with coherent structure."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "srXC6pLGLwS6"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WGyKZj3bzf9p"
   },
   "source": [
    "### Import TensorFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 397
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6412,
     "status": "ok",
     "timestamp": 1564756501128,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "yG_n40gFzf9s",
    "outputId": "08cf79a9-4f5b-4ee2-afb7-209179e1949b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: tensorflow-gpu==2.0.0-beta1 in /usr/local/lib/python3.6/dist-packages (2.0.0b1)\n",
      "Requirement already satisfied: absl-py>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (0.7.1)\n",
      "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.1.0)\n",
      "Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.15.0)\n",
      "Requirement already satisfied: tb-nightly<1.14.0a20190604,>=1.14.0a20190603 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.14.0a20190603)\n",
      "Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (3.7.1)\n",
      "Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.12.0)\n",
      "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (0.33.4)\n",
      "Requirement already satisfied: google-pasta>=0.1.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (0.1.7)\n",
      "Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.0.8)\n",
      "Requirement already satisfied: numpy<2.0,>=1.14.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.16.4)\n",
      "Requirement already satisfied: gast>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (0.2.2)\n",
      "Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.11.2)\n",
      "Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.1.0)\n",
      "Requirement already satisfied: astor>=0.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (0.8.0)\n",
      "Requirement already satisfied: tf-estimator-nightly<1.14.0.dev2019060502,>=1.14.0.dev2019060501 in /usr/local/lib/python3.6/dist-packages (from tensorflow-gpu==2.0.0-beta1) (1.14.0.dev2019060501)\n",
      "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<1.14.0a20190604,>=1.14.0a20190603->tensorflow-gpu==2.0.0-beta1) (0.15.5)\n",
      "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<1.14.0a20190604,>=1.14.0a20190603->tensorflow-gpu==2.0.0-beta1) (3.1.1)\n",
      "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<1.14.0a20190604,>=1.14.0a20190603->tensorflow-gpu==2.0.0-beta1) (41.0.1)\n",
      "Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from keras-applications>=1.0.6->tensorflow-gpu==2.0.0-beta1) (2.8.0)\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "!pip install tensorflow-gpu==2.0.0-beta1\n",
    "import tensorflow as tf\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EHDoRoc5PKWz"
   },
   "source": [
    "### Download the Shakespeare dataset\n",
    "Change the following line to run this code on your own data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "pD_55cOxLkAb"
   },
   "outputs": [],
   "source": [
    "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UHjdCjDuSvX_"
   },
   "source": [
    "### Read the data\n",
    "\n",
    "First, take a look at the text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8756,
     "status": "ok",
     "timestamp": 1564756503495,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "aavnuByVymwK",
    "outputId": "9eabfb6e-a151-4ef3-e8da-4c3df1c045bf"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of text: 1115394 characters\n"
     ]
    }
   ],
   "source": [
    "# Read, then decode for py2 compat.\n",
    "text = open(path_to_file, 'rb').read().decode(encoding='utf-8')\n",
    "# length of text is the number of characters in it\n",
    "print ('Length of text: {} characters'.format(len(text)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 287
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8749,
     "status": "ok",
     "timestamp": 1564756503495,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "Duhg9NrUymwO",
    "outputId": "127f9e8c-71cc-4c8e-e09d-555e7f20164a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First Citizen:\n",
      "Before we proceed any further, hear me speak.\n",
      "\n",
      "All:\n",
      "Speak, speak.\n",
      "\n",
      "First Citizen:\n",
      "You are all resolved rather to die than to famish?\n",
      "\n",
      "All:\n",
      "Resolved. resolved.\n",
      "\n",
      "First Citizen:\n",
      "First, you know Caius Marcius is chief enemy to the people.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Take a look at the first 250 characters in text\n",
    "print(text[:250])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8741,
     "status": "ok",
     "timestamp": 1564756503496,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "IlCgQBRVymwR",
    "outputId": "d430e2fd-8c5d-4798-eeec-8d24ce3dec5c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65 unique characters\n"
     ]
    }
   ],
   "source": [
    "# The unique characters in the file\n",
    "vocab = sorted(set(text))\n",
    "print ('{} unique characters'.format(len(vocab)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rNnrKn_lL-IJ"
   },
   "source": [
    "## Process the text"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LFjSVAlWzf-N"
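   },
   "source": [
    "### Vectorize the text\n",
    "Before training, we need to map the strings to a numerical representation. Create two lookup tables: one mapping characters to numbers, and another mapping numbers back to characters."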
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "IalZLbvOzf-F"
   },
   "outputs": [],
   "source": [
    "# Creating a mapping from unique characters to indices\n",
    "char2idx = {u:i for i, u in enumerate(vocab)}\n",
    "idx2char = np.array(vocab)\n",
    "\n",
    "text_as_int = np.array([char2idx[c] for c in text])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tZfqhkYCymwX"
   },
   "source": [
    "Now we have an integer representation for each character. Note that we mapped the characters to indices from 0 to `len(vocab) - 1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 431
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 9081,
     "status": "ok",
     "timestamp": 1564756503855,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "FYyNlCNXymwY",
    "outputId": "7293ffef-0ed8-4ae6-9a54-3397f989b83c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  '\\n':   0,\n",
      "  ' ' :   1,\n",
      "  '!' :   2,\n",
      "  '$' :   3,\n",
      "  '&' :   4,\n",
      "  \"'\" :   5,\n",
      "  ',' :   6,\n",
      "  '-' :   7,\n",
      "  '.' :   8,\n",
      "  '3' :   9,\n",
      "  ':' :  10,\n",
      "  ';' :  11,\n",
      "  '?' :  12,\n",
      "  'A' :  13,\n",
      "  'B' :  14,\n",
      "  'C' :  15,\n",
      "  'D' :  16,\n",
      "  'E' :  17,\n",
      "  'F' :  18,\n",
      "  'G' :  19,\n",
      "  ...\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print('{')\n",
    "for char,_ in zip(char2idx, range(20)):\n",
    "    print('  {:4s}: {:3d},'.format(repr(char), char2idx[char]))\n",
    "print('  ...\\n}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 9073,
     "status": "ok",
     "timestamp": 1564756503855,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "l1VKcQHcymwb",
    "outputId": "0fac4352-89b7-4a9d-c2b4-8e9b38db5714"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'First Citizen' ---- characters mapped to int ---- > [18 47 56 57 58  1 15 47 58 47 64 43 52]\n"
     ]
    }
   ],
   "source": [
    "# Show how the first 13 characters from the text are mapped to integers\n",
    "print ('{} ---- characters mapped to int ---- > {}'.format(repr(text[:13]), text_as_int[:13]))"
   ]
  },
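  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two lookup tables can be checked with a round trip on a tiny string. The following is a minimal standalone sketch that uses plain Python lists in place of NumPy arrays. The names `toy_text`, `toy_vocab`, `toy_char2idx`, `toy_idx2char`, `toy_ids` and `decode` are illustrative assumptions and are not part of this tutorial's code.\n",
    "\n",
    "```python\n",
    "# Toy text standing in for the Shakespeare corpus (assumption for illustration).\n",
    "toy_text = 'hello world'\n",
    "\n",
    "# Sorted unique characters form the vocabulary.\n",
    "toy_vocab = sorted(set(toy_text))\n",
    "\n",
    "# Forward table: character -> index. Reverse table: index -> character.\n",
    "toy_char2idx = {ch: i for i, ch in enumerate(toy_vocab)}\n",
    "toy_idx2char = list(toy_vocab)\n",
    "\n",
    "# Encode the text as a list of integer indices.\n",
    "toy_ids = [toy_char2idx[ch] for ch in toy_text]\n",
    "\n",
    "def decode(indices):\n",
    "    # Map each index back to its character and join the result.\n",
    "    return ''.join(toy_idx2char[i] for i in indices)\n",
    "\n",
    "print(toy_vocab)\n",
    "print(toy_ids)\n",
    "print(decode(toy_ids))\n",
    "```\n",
    "\n",
    "Decoding the encoded indices should reproduce the original string, and the largest index should equal `len(toy_vocab) - 1`."
   ]
  },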
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bbmsf23Bymwe"
   },
   "source": [
    "### The prediction task"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wssHQ1oGymwe"
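   },
   "source": [
    "\n",
    "Given a character, or a sequence of characters, what is the most probable next character? This is the task we are training the model to perform. The input to the model will be a sequence of characters, and we train the model to predict the output: the following character at each time step.\n",
    "\n",
    "Since RNNs maintain an internal state that depends on the previously seen elements, the question becomes: given all the characters computed until this moment, what is the next character?"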
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hgsVvVxnymwf"
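   },
   "source": [
    "### Create training examples and targets\n",
    "Next, divide the text into example sequences. Each input sequence contains `seq_length` characters from the text.\n",
    "\n",
    "For each input sequence, the corresponding target contains the same length of text, shifted one character to the right.\n",
    "\n",
    "So break the text into chunks of `seq_length+1`. For example, say `seq_length` is 4 and our text is \"Hello\". The input sequence would be \"Hell\", and the target sequence \"ello\".\n",
    "\n",
    "To do this, first use the `tf.data.Dataset.from_tensor_slices` function to convert the text vector into a stream of character indices."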
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 107
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 9448,
     "status": "ok",
     "timestamp": 1564756504237,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "0UHJDA39zf-O",
    "outputId": "9e7df249-7f19-4689-804c-8705f79792a5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F\n",
      "i\n",
      "r\n",
      "s\n",
      "t\n"
     ]
    }
   ],
   "source": [
    "# The maximum length sentence we want for a single input in characters\n",
    "seq_length = 100\n",
    "examples_per_epoch = len(text)//seq_length\n",
    "\n",
    "# Create training examples / targets\n",
    "char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)\n",
    "\n",
    "for i in char_dataset.take(5):\n",
    "  print(idx2char[i.numpy()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-ZSYAcQV8OGP"
   },
   "source": [
    "The `batch` method lets us easily convert these individual characters into sequences of the desired size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 107
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 9437,
     "status": "ok",
     "timestamp": 1564756504238,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "l4hkDU3i7ozi",
    "outputId": "a95fabae-b3fa-460f-b664-7d66d13cabeb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n",
      "'are all resolved rather to die than to famish?\\n\\nAll:\\nResolved. resolved.\\n\\nFirst Citizen:\\nFirst, you k'\n",
      "\"now Caius Marcius is chief enemy to the people.\\n\\nAll:\\nWe know't, we know't.\\n\\nFirst Citizen:\\nLet us ki\"\n",
      "\"ll him, and we'll have corn at our own price.\\nIs't a verdict?\\n\\nAll:\\nNo more talking on't; let it be d\"\n",
      "'one: away, away!\\n\\nSecond Citizen:\\nOne word, good citizens.\\n\\nFirst Citizen:\\nWe are accounted poor citi'\n"
     ]
    }
   ],
   "source": [
    "sequences = char_dataset.batch(seq_length+1, drop_remainder=True)\n",
    "\n",
    "for item in sequences.take(5):\n",
    "  print(repr(''.join(idx2char[item.numpy()])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UbLcIPBj_mWZ"
   },
   "source": [
    "For each sequence, duplicate and shift it to form the input and target text, using the `map` method to apply a simple function to each batch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9NGu-FkO_kYU"
   },
   "outputs": [],
   "source": [
    "def split_input_target(chunk):\n",
    "    input_text = chunk[:-1]\n",
    "    target_text = chunk[1:]\n",
    "    return input_text, target_text\n",
    "\n",
    "dataset = sequences.map(split_input_target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hiCopyGZymwi"
   },
   "source": [
    "Print the first example's input and target values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 9411,
     "status": "ok",
     "timestamp": 1564756504239,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "GNbw-iR0ymwj",
    "outputId": "4240109b-6f9f-46b2-8126-609fee3b7bf4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input data:  'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou'\n",
      "Target data: 'irst Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.\\n\\nFirst Citizen:\\nYou '\n"
     ]
    }
   ],
   "source": [
    "for input_example, target_example in  dataset.take(1):\n",
    "  print ('Input data: ', repr(''.join(idx2char[input_example.numpy()])))\n",
    "  print ('Target data:', repr(''.join(idx2char[target_example.numpy()])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_33OHL3b84i0"
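   },
   "source": [
    "Each index of these vectors is processed as one time step. For the input at time step 0, the model receives the index for \"F\" and tries to predict the index for \"i\" as the next character. At the next time step it does the same thing, but the RNN considers the context from the previous steps in addition to the current input character."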
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 287
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 10188,
     "status": "ok",
     "timestamp": 1564756505026,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "0eBu9WZG84i0",
    "outputId": "0a467f77-f745-45b6-dc7f-559a4f4a05a5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step    0\n",
      "  input: 18 ('F')\n",
      "  expected output: 47 ('i')\n",
      "Step    1\n",
      "  input: 47 ('i')\n",
      "  expected output: 56 ('r')\n",
      "Step    2\n",
      "  input: 56 ('r')\n",
      "  expected output: 57 ('s')\n",
      "Step    3\n",
      "  input: 57 ('s')\n",
      "  expected output: 58 ('t')\n",
      "Step    4\n",
      "  input: 58 ('t')\n",
      "  expected output: 1 (' ')\n"
     ]
    }
   ],
   "source": [
    "for i, (input_idx, target_idx) in enumerate(zip(input_example[:5], target_example[:5])):\n",
    "    print(\"Step {:4d}\".format(i))\n",
    "    print(\"  input: {} ({:s})\".format(input_idx, repr(idx2char[input_idx])))\n",
    "    print(\"  expected output: {} ({:s})\".format(target_idx, repr(idx2char[target_idx])))"
   ]
  },
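  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The chunk-and-shift scheme can be illustrated without TensorFlow. The sketch below works on ordinary Python strings rather than a `tf.data` pipeline. `make_chunks` is a hypothetical helper that imitates `batch(seq_length+1, drop_remainder=True)`, and `shift_pair` repeats the slicing logic of `split_input_target` under a different name.\n",
    "\n",
    "```python\n",
    "def make_chunks(text, seq_length):\n",
    "    # Cut the text into pieces of seq_length + 1 characters and\n",
    "    # drop any incomplete piece at the end.\n",
    "    size = seq_length + 1\n",
    "    count = len(text) // size\n",
    "    return [text[i * size:(i + 1) * size] for i in range(count)]\n",
    "\n",
    "def shift_pair(chunk):\n",
    "    # Input drops the last character; target drops the first.\n",
    "    return chunk[:-1], chunk[1:]\n",
    "\n",
    "toy_seq_length = 4\n",
    "toy_chunks = make_chunks('Hello world!', toy_seq_length)\n",
    "toy_input, toy_target = shift_pair(toy_chunks[0])\n",
    "\n",
    "# At every step the target is the character that follows the input.\n",
    "for step, (inp, tgt) in enumerate(zip(toy_input, toy_target)):\n",
    "    print(step, repr(inp), '->', repr(tgt))\n",
    "```\n",
    "\n",
    "With `toy_seq_length = 4`, the first chunk is the five-character string from the example in the text, so the input and target each contain four characters."
   ]
  },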
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MJdfPmdqzf-R"
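   },
   "source": [
    "### Create training batches\n",
    "We used `tf.data` to split the text into manageable sequences. But before feeding this data into the model, we need to shuffle the data and pack it into batches.\n"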
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 10178,
     "status": "ok",
     "timestamp": 1564756505027,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "p2pGotuNzf-S",
    "outputId": "3d934eef-7b8d-4391-af37-67de1d850b7e"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<BatchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>"
      ]
     },
     "execution_count": 16,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Batch size\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "# Buffer size to shuffle the dataset\n",
    "# (TF data is designed to work with possibly infinite sequences,\n",
    "# so it doesn't attempt to shuffle the entire sequence in memory. Instead,\n",
    "# it maintains a buffer in which it shuffles elements).\n",
    "BUFFER_SIZE = 10000\n",
    "\n",
    "dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "r6oUuElIMgVx"
   },
   "source": [
    "## Build the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "m8gPwEjRzf-Z"
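   },
   "source": [
    "\n",
    "Use `tf.keras.Sequential` to define the model. This simple example uses three layers:\n",
    "\n",
    "- `tf.keras.layers.Embedding`: the input layer. A trainable lookup table that maps each character's index to a vector with `embedding_dim` dimensions.\n",
    "- `tf.keras.layers.LSTM`: a type of RNN with size `units=rnn_units`. (You can also use a GRU layer here.)\n",
    "- `tf.keras.layers.Dense`: the output layer, with `vocab_size` outputs."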
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zHT8cLh7EAsg"
   },
   "outputs": [],
   "source": [
    "# Length of the vocabulary in chars\n",
    "vocab_size = len(vocab)\n",
    "\n",
    "# The embedding dimension\n",
    "embedding_dim = 256\n",
    "\n",
    "# Number of RNN units\n",
    "rnn_units = 1024"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MtCrdfzEI2N0"
   },
   "outputs": [],
   "source": [
    "def build_model(vocab_size, embedding_dim, rnn_units, batch_size):\n",
    "  model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Embedding(vocab_size, embedding_dim,\n",
    "                              batch_input_shape=[batch_size, None]),\n",
    "    tf.keras.layers.LSTM(rnn_units,\n",
    "                        return_sequences=True,\n",
    "                        stateful=True,\n",
    "                        recurrent_initializer='glorot_uniform'),\n",
    "    tf.keras.layers.Dense(vocab_size)\n",
    "  ])\n",
    "  return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wwsrpOik5zhv"
   },
   "outputs": [],
   "source": [
    "model = build_model(\n",
    "  vocab_size = len(vocab),\n",
    "  embedding_dim=embedding_dim,\n",
    "  rnn_units=rnn_units,\n",
    "  batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RkA5upJIJ7W7"
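   },
   "source": [
    "For each character, the model looks up the embedding, runs the LSTM one time step with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character:\n",
    "\n",
    "![A drawing of the data passing through the model](https://tensorflow.org/tutorials/beta/text/images/text_generation_training.png)"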
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-ubPo0_9Prjb"
   },
   "source": [
    "## Try the model\n",
    "Now run the model to check that it behaves as expected.\n",
    "\n",
    "First, check the shape of the output:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 14644,
     "status": "ok",
     "timestamp": 1564756509531,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "C-_70kKAPrPU",
    "outputId": "c23fee0b-d816-4d3c-b363-e26693450ccf"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(64, 100, 65) # (batch_size, sequence_length, vocab_size)\n"
     ]
    }
   ],
   "source": [
    "for input_example_batch, target_example_batch in dataset.take(1):\n",
    "  example_batch_predictions = model(input_example_batch)\n",
    "  print(example_batch_predictions.shape, \"# (batch_size, sequence_length, vocab_size)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Q6NzLBi4VM4o"
   },
   "source": [
    "In the example above the sequence length of the input is 100, but the model can be run on inputs of any length:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 269
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 14636,
     "status": "ok",
     "timestamp": 1564756509532,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "vPGmAAXmVLGC",
    "outputId": "d0b9e5b6-33ca-456f-eec5-0295777b91d0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding (Embedding)        (64, None, 256)           16640     \n",
      "_________________________________________________________________\n",
      "lstm (LSTM)                  (64, None, 1024)          5246976   \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (64, None, 65)            66625     \n",
      "=================================================================\n",
      "Total params: 5,330,241\n",
      "Trainable params: 5,330,241\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "uwv0gEkURfx1"
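   },
   "source": [
    "To get actual predictions from the model, we need to sample from the output distribution to obtain actual character indices. This distribution is defined by the logits over the character vocabulary.\n",
    "> Note: It is important to sample from this distribution, because taking the argmax of the distribution can easily get the model stuck in a loop.\n",
    "\n",
    "Try it for the first example in the batch:"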
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4V4MfFg0RQJg"
   },
   "outputs": [],
   "source": [
    "sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)\n",
    "sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QM1Vbxs_URw5"
   },
   "source": [
    "This gives us, at each time step, a prediction of the next character index:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 125
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 14593,
     "status": "ok",
     "timestamp": 1564756509532,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "YqFMUQc_UFgM",
    "outputId": "6046fe90-c765-4edd-bb56-2e29557b5fe8"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([41,  3, 61, 21, 12, 44, 40, 24, 54, 19, 57, 51, 53, 45,  9, 52, 62,\n",
       "       46, 58, 12, 62, 54, 62, 22, 25, 23, 25,  0, 42, 53, 43,  5, 45, 51,\n",
       "       64, 57, 20, 13, 14, 14, 36, 55, 19,  5, 24, 15, 15, 62,  9, 53, 54,\n",
       "       46, 28, 37, 30, 29, 32, 64, 29, 51, 52,  0, 54, 51, 50,  7, 51, 37,\n",
       "       49, 33, 23, 22, 46, 56, 25, 43, 52, 26,  9, 33, 56, 45,  7, 10, 11,\n",
       "       58, 45, 29, 64, 61, 37, 64, 24, 25, 62, 47,  9, 42, 57, 18])"
      ]
     },
     "execution_count": 23,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sampled_indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LfLtsP3mUhCG"
   },
   "source": [
    "Decode these to see the text predicted by this untrained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 107
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 14578,
     "status": "ok",
     "timestamp": 1564756509533,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "xWcFwPwLSo05",
    "outputId": "5bd7b655-3417-4209-b5fb-394f0e62ab1f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: \n",
      " 'g still, in cleansing them from tears.\\nNow sir, the sound that tells what hour it is\\nAre clamorous g'\n",
      "\n",
      "Next Char Predictions: \n",
      " \"c$wI?fbLpGsmog3nxht?xpxJMKM\\ndoe'gmzsHABBXqG'LCCx3ophPYRQTzQmn\\npml-mYkUKJhrMenN3Urg-:;tgQzwYzLMxi3dsF\"\n"
     ]
    }
   ],
   "source": [
    "print(\"Input: \\n\", repr(\"\".join(idx2char[input_example_batch[0]])))\n",
    "print()\n",
    "print(\"Next Char Predictions: \\n\", repr(\"\".join(idx2char[sampled_indices ])))"
   ]
  },
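  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between sampling and taking the argmax can be seen on a toy example. The sketch below uses only the Python standard library and is not the tutorial's TensorFlow code. `softmax`, `sample_index` and `toy_logits` are illustrative names, and the three logit values are made up. It assumes, as `tf.random.categorical` does, that the logits are unnormalized log-probabilities.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the maximum before exponentiating for numerical stability.\n",
    "    peak = max(logits)\n",
    "    exps = [math.exp(x - peak) for x in logits]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "def sample_index(logits, rng):\n",
    "    # Draw one class index with probability softmax(logits)[index].\n",
    "    probs = softmax(logits)\n",
    "    return rng.choices(range(len(logits)), weights=probs, k=1)[0]\n",
    "\n",
    "rng = random.Random(0)          # fixed seed so runs are repeatable\n",
    "toy_logits = [2.0, 1.0, 0.1]    # made-up scores for three classes\n",
    "\n",
    "# Argmax is deterministic: the same logits always give the same index.\n",
    "greedy = max(range(len(toy_logits)), key=lambda i: toy_logits[i])\n",
    "\n",
    "# Sampling visits every class, in proportion to its probability.\n",
    "samples = [sample_index(toy_logits, rng) for _ in range(1000)]\n",
    "counts = [samples.count(i) for i in range(len(toy_logits))]\n",
    "\n",
    "print('argmax index:', greedy)\n",
    "print('sampled counts per index:', counts)\n",
    "```\n",
    "\n",
    "Because the argmax never varies for a given context, a generation loop built on it can repeat the same characters indefinitely, while sampling keeps some variety in the output."
   ]
  },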
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LJL0Q0YPY6Ee"
   },
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "YCbHQHiaa4Ic"
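   },
   "source": [
    "At this point the problem can be treated as a standard classification problem. Given the previous RNN state and the input at this time step, predict the class of the next character."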
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "trpqTWyvk0nr"
   },
   "source": [
    "### Attach an optimizer and a loss function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UAjbjY03eiQ4"
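   },
   "source": [
    "The standard `tf.keras.losses.sparse_categorical_crossentropy` loss function works in this case because it is applied across the last dimension of the predictions.\n",
    "\n",
    "Because our model returns logits, we need to set the `from_logits` flag."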
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 14568,
     "status": "ok",
     "timestamp": 1564756509533,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "4HrXTACTdzY-",
    "outputId": "90eb3df2-4df4-4732-e28a-4a3bd3473f11"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction shape:  (64, 100, 65)  # (batch_size, sequence_length, vocab_size)\n",
      "scalar_loss:       4.175114\n"
     ]
    }
   ],
   "source": [
    "def loss(labels, logits):\n",
    "  return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)\n",
    "\n",
    "example_batch_loss  = loss(target_example_batch, example_batch_predictions)\n",
    "print(\"Prediction shape: \", example_batch_predictions.shape, \" # (batch_size, sequence_length, vocab_size)\")\n",
    "print(\"scalar_loss:      \", example_batch_loss.numpy().mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jeOXriLcymww"
   },
   "source": [
    "使用该tf.keras.Model.compile方法配置培训过程。我们将使用tf.keras.optimizers.Adam默认参数和损失函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "DDl1_Een6rL0"
   },
   "outputs": [],
   "source": [
    "model.compile(optimizer='adam', loss=loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ieSJdchZggUj"
   },
   "source": [
    "### 配置检查点"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "C6XBUUavgF56"
   },
   "source": [
    "使用a tf.keras.callbacks.ModelCheckpoint确保在培训期间保存检查点："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "W6fWTriUZP-n"
   },
   "outputs": [],
   "source": [
    "# Directory where the checkpoints will be saved\n",
    "checkpoint_dir = './training_checkpoints'\n",
    "# Name of the checkpoint files\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n",
    "\n",
    "checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(\n",
    "    filepath=checkpoint_prefix,\n",
    "    save_weights_only=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3Ky3F_BhgkTW"
   },
   "source": [
    "### 执行培训"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "IxdOA-rgyGvs"
   },
   "source": [
    "为了使训练时间合理，使用10个时期来训练模型。在Colab中，将运行时设置为GPU以便更快地进行训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7yGBE2zxMMHs"
   },
   "outputs": [],
   "source": [
    "EPOCHS=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 377
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 153379,
     "status": "ok",
     "timestamp": 1564756648387,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "UK-hmKjYVoll",
    "outputId": "794f87e8-ff8d-43ef-d4da-472f94f0cbba"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "172/172 [==============================] - 14s 83ms/step - loss: 2.6193\n",
      "Epoch 2/10\n",
      "172/172 [==============================] - 13s 78ms/step - loss: 1.9240\n",
      "Epoch 3/10\n",
      "172/172 [==============================] - 14s 80ms/step - loss: 1.6648\n",
      "Epoch 4/10\n",
      "172/172 [==============================] - 14s 81ms/step - loss: 1.5218\n",
      "Epoch 5/10\n",
      "172/172 [==============================] - 14s 82ms/step - loss: 1.4335\n",
      "Epoch 6/10\n",
      "172/172 [==============================] - 14s 82ms/step - loss: 1.3709\n",
      "Epoch 7/10\n",
      "172/172 [==============================] - 14s 81ms/step - loss: 1.3206\n",
      "Epoch 8/10\n",
      "172/172 [==============================] - 14s 80ms/step - loss: 1.2758\n",
      "Epoch 9/10\n",
      "172/172 [==============================] - 14s 80ms/step - loss: 1.2330\n",
      "Epoch 10/10\n",
      "172/172 [==============================] - 14s 80ms/step - loss: 1.1915\n"
     ]
    }
   ],
   "source": [
    "history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kKkD5M6eoSiN"
   },
   "source": [
    "## 生成文本"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JIPcXllKjkdr"
   },
   "source": [
    "### 恢复最新的检查点"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LyeYRiuVjodY"
   },
   "source": [
    "\n",
    "要使此预测步骤简单，请使用批处理大小1。\n",
    "\n",
    "由于RNN状态从时间步长传递到时间步的方式，模型一旦构建就只接受固定的批量大小。\n",
    "\n",
    "要使用不同的模型运行模型batch_size，我们需要重建模型并从检查点恢复权重。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 153371,
     "status": "ok",
     "timestamp": 1564756648388,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "zk2WJ2-XjkGz",
    "outputId": "9088cc5b-bafa-49cb-a3a4-ee5341b578ce"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'./training_checkpoints/ckpt_10'"
      ]
     },
     "execution_count": 30,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.train.latest_checkpoint(checkpoint_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LycQ-ot_jjyu"
   },
   "outputs": [],
   "source": [
    "model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)\n",
    "\n",
    "model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))\n",
    "\n",
    "model.build(tf.TensorShape([1, None]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 269
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 154683,
     "status": "ok",
     "timestamp": 1564756649725,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "71xa6jnYVrAN",
    "outputId": "66ae1561-327f-47b9-c1eb-a6226417d126"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "embedding_1 (Embedding)      (1, None, 256)            16640     \n",
      "_________________________________________________________________\n",
      "lstm_1 (LSTM)                (1, None, 1024)           5246976   \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (1, None, 65)             66625     \n",
      "=================================================================\n",
      "Total params: 5,330,241\n",
      "Trainable params: 5,330,241\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DjGz1tDkzf-u"
   },
   "source": [
    "预测循环\n",
    "以下代码块生成文本：\n",
    "\n",
    "- 它首先选择一个起始字符串，初始化RNN状态并设置要生成的字符数。\n",
    "\n",
    "- 使用起始字符串和RNN状态获取下一个字符的预测分布。\n",
    "\n",
    "- 然后，使用分类分布来计算预测字符的索引。使用此预测字符作为模型的下一个输入。\n",
    "\n",
    "- 模型返回的RNN状态被反馈到模型中，以便它现在具有更多上下文，而不是仅有一个单词。在预测下一个单词之后，修改后的RNN状态再次反馈到模型中，这是它从先前预测的单词获得更多上下文时的学习方式。\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "![要生成文本，模型的输出将反馈到输入](https://tensorflow.org/tutorials/beta/text/images/text_generation_sampling.png)\n",
    "查看生成的文本，您将看到模型知道何时大写，制作段落并模仿类似莎士比亚的写作词汇。由于训练时代数量较少，尚未学会形成连贯的句子。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WvuwZBX5Ogfd"
   },
   "outputs": [],
   "source": [
    "def generate_text(model, start_string):\n",
    "  # Evaluation step (generating text using the learned model)\n",
    "\n",
    "  # Number of characters to generate\n",
    "  num_generate = 1000\n",
    "\n",
    "  # Converting our start string to numbers (vectorizing)\n",
    "  input_eval = [char2idx[s] for s in start_string]\n",
    "  input_eval = tf.expand_dims(input_eval, 0)\n",
    "\n",
    "  # Empty string to store our results\n",
    "  text_generated = []\n",
    "\n",
    "  # Low temperatures results in more predictable text.\n",
    "  # Higher temperatures results in more surprising text.\n",
    "  # Experiment to find the best setting.\n",
    "  temperature = 1.0\n",
    "\n",
    "  # Here batch size == 1\n",
    "  model.reset_states()\n",
    "  for i in range(num_generate):\n",
    "      predictions = model(input_eval)\n",
    "      # remove the batch dimension\n",
    "      predictions = tf.squeeze(predictions, 0)\n",
    "\n",
    "      # using a categorical distribution to predict the word returned by the model\n",
    "      predictions = predictions / temperature\n",
    "      predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()\n",
    "\n",
    "      # We pass the predicted word as the next input to the model\n",
    "      # along with the previous hidden state\n",
    "      input_eval = tf.expand_dims([predicted_id], 0)\n",
    "\n",
    "      text_generated.append(idx2char[predicted_id])\n",
    "\n",
    "  return (start_string + ''.join(text_generated))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 55
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 159152,
     "status": "ok",
     "timestamp": 1564756654219,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "ktovv0RFhrkn",
    "outputId": "e7139557-b924-4656-da8e-a7e3f15ddc98"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ROMEO: VINJBYOO3U$zZWV?NU3UYZfUj&XYK3XKJjUY&&QKPQPQUXOzW?WGZJZJMB&O$qUQj3QUFJ$Wq&WxvRZj&zV'UY$NYVYDYXTrzJxY$QjQ$X&GjWOIUQAq.;QxZXVQzq&GGACjj3qGV&ZDWZzYYYU!&qZX$GYq&Z&vYVX&QCPUzzCZ!UQZVYzGBqY$B333&zk3z$J3UVE$OG!EE!NUY&PXjqVjMIfUYXq33?MjOQ3Y3XZzU$GXXWLZzZUYYG&KENvN&jJ3MQGjYWB3QxjxAU!ZQZ-JDJ$RUYqXDZ&!&jWTjJUWEV.YqYZZ3xX&wL&z3j$z&jUVQzJQj&DjWxYU.jQjY&YQD3j$G.&?,GNjjVY&jZWlQXUzz&kzZUJZjjUcUG&&jZJGJ$zVUjQ-VzS&G3JN&3&qWY&zxZZV$DQzYzYQjUvBAQS-3YY&XDVjzBNBJVVU$&QUwZxIjA$zQDQqZJZVU&J3NBZJVJNzTNqY33YW3$qW&DVEZVOQIqXXVON&&YXN&kJ3qY3&&GBxY!XS!Y&SUJ$Jz&kEU3&UzUZUQqNGqD$XjXUyYZN$Qj&WVYYEDXzqOWXYPU&zVEUjUY$DDJO$U&VVzGqjUkD&qjQQV&Yjjj$W&JSZY&&GjQ.qYj$PYUDY$&BWzZqqYXVROJW$YYzAjqQ&$&J&ZUVUqK$3M?XMq&UTZq&XqjjOWjMYX$&j$jUYQGEUMwUMDqYUDZjJYYzUqUW&QaYZVWjAV:UQ&EvPDxZ$Z$zqOUTzUYM&YFUHY;ZG&U&bx!jW3U! BUzJQJ3AUG&3NjJZ&3&Q$J-VYNSUz!&ZMqAYQUj:&UqI&X;xzQBUQQAG$z&X&zxXUIY3VUWE$Y;&U&Z&XZQq3Q$V&YJ&Q3Z&KUQDJbQ&LBBT?X&YA&xG?CjD&&DGVU$YCGJD3V$SGxQU$HAJUXUZDj&jUkxUqZ$DXTDG3YW&ZXjqO&GzUX3U3&OZUj!B$3x&3j!3JXXUQ&YW$UDkqPA3DZVQXJ&G\n"
     ]
    }
   ],
   "source": [
    "print(generate_text(model, start_string=u\"ROMEO: \"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AM2Uma_-yVIq"
   },
   "source": [
    "你可以做的最简单的事情是改善结果，以便更长时间地训练它（尝试EPOCHS=30）。\n",
    "\n",
    "您还可以尝试使用不同的起始字符串，或尝试添加另一个RNN图层以提高模型的准确性，或者调整温度参数以生成更多或更少的随机预测。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Y4QwTjAM6A2O"
   },
   "source": [
    "## 上述训练程序很简单，但不会给你太多控制。\n",
    "\n",
    "所以现在您已经了解了如何手动运行模型，让我们解压缩训练循环，并自己实现。例如，如果要实施课程学习以帮助稳定模型的开环输出，这就是一个起点。\n",
    "\n",
    "我们将用于tf.GradientTape跟踪渐变。您可以通过阅读热切的执行指南来了解有关此方法的更多信息。\n",
    "\n",
    "该程序的工作原理如下：\n",
    "\n",
    "- 首先，初始化RNN状态。我们通过调用tf.keras.Model.reset_states方法来完成此操作。\n",
    "\n",
    "- 接下来，迭代数据集（逐批）并计算与每个数据集关联的预测。\n",
    "\n",
    "- 打开a tf.GradientTape，并计算该上下文中的预测和损失。\n",
    "\n",
    "- 使用该tf.GradientTape.grads方法计算相对于模型变量的损失梯度。\n",
    "\n",
    "- 最后，使用优化器的tf.train.Optimizer.apply_gradients方法向下迈出一步。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_XAm7eCoKULT"
   },
   "outputs": [],
   "source": [
    "model = build_model(\n",
    "  vocab_size = len(vocab),\n",
    "  embedding_dim=embedding_dim,\n",
    "  rnn_units=rnn_units,\n",
    "  batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qUKhnZtMVpoJ"
   },
   "outputs": [],
   "source": [
    "optimizer = tf.keras.optimizers.Adam()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "b4kH1o0leVIp"
   },
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(inp, target):\n",
    "  with tf.GradientTape() as tape:\n",
    "    predictions = model(inp)\n",
    "    loss = tf.reduce_mean(\n",
    "        tf.keras.losses.sparse_categorical_crossentropy(\n",
    "            target, predictions, from_logits=True))\n",
    "  grads = tape.gradient(loss, model.trainable_variables)\n",
    "  optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
    "\n",
    "  return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 295217,
     "status": "ok",
     "timestamp": 1564756790330,
     "user": {
      "displayName": "Will Chen",
      "photoUrl": "",
      "userId": "01179718990779759737"
     },
     "user_tz": -480
    },
    "id": "d4tSNwymzf-q",
    "outputId": "894696d7-6ac7-4e0e-f469-ce07f98ac3d3"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0802 14:37:35.460247 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer\n",
      "W0802 14:37:35.461703 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer.iter\n",
      "W0802 14:37:35.462915 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer.beta_1\n",
      "W0802 14:37:35.464316 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer.beta_2\n",
      "W0802 14:37:35.465240 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer.decay\n",
      "W0802 14:37:35.467100 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer.learning_rate\n",
      "W0802 14:37:35.468499 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.embeddings\n",
      "W0802 14:37:35.469516 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-2.kernel\n",
      "W0802 14:37:35.470415 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-2.bias\n",
      "W0802 14:37:35.471374 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-1.cell.kernel\n",
      "W0802 14:37:35.472301 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-1.cell.recurrent_kernel\n",
      "W0802 14:37:35.473214 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-1.cell.bias\n",
      "W0802 14:37:35.474144 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.embeddings\n",
      "W0802 14:37:35.475064 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-2.kernel\n",
      "W0802 14:37:35.476021 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-2.bias\n",
      "W0802 14:37:35.476972 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-1.cell.kernel\n",
      "W0802 14:37:35.477811 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-1.cell.recurrent_kernel\n",
      "W0802 14:37:35.478764 139895554488192 util.py:244] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-1.cell.bias\n",
      "W0802 14:37:35.479741 139895554488192 util.py:252] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/alpha/guide/checkpoints#loading_mechanics for details.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Batch 0 Loss 4.174371719360352\n",
      "Epoch 1 Batch 100 Loss 2.323850631713867\n",
      "Epoch 1 Loss 2.1186\n",
      "Time taken for 1 epoch 15.880910158157349 sec\n",
      "\n",
      "Epoch 2 Batch 0 Loss 2.0843210220336914\n",
      "Epoch 2 Batch 100 Loss 1.83076810836792\n",
      "Epoch 2 Loss 1.7715\n",
      "Time taken for 1 epoch 13.457166194915771 sec\n",
      "\n",
      "Epoch 3 Batch 0 Loss 1.7104278802871704\n",
      "Epoch 3 Batch 100 Loss 1.5949002504348755\n",
      "Epoch 3 Loss 1.5931\n",
      "Time taken for 1 epoch 13.458112716674805 sec\n",
      "\n",
      "Epoch 4 Batch 0 Loss 1.5357296466827393\n",
      "Epoch 4 Batch 100 Loss 1.4776993989944458\n",
      "Epoch 4 Loss 1.4852\n",
      "Time taken for 1 epoch 13.257809162139893 sec\n",
      "\n",
      "Epoch 5 Batch 0 Loss 1.4319688081741333\n",
      "Epoch 5 Batch 100 Loss 1.4038314819335938\n",
      "Epoch 5 Loss 1.4126\n",
      "Time taken for 1 epoch 13.24284291267395 sec\n",
      "\n",
      "Epoch 6 Batch 0 Loss 1.3629602193832397\n",
      "Epoch 6 Batch 100 Loss 1.3476572036743164\n",
      "Epoch 6 Loss 1.3545\n",
      "Time taken for 1 epoch 13.140316724777222 sec\n",
      "\n",
      "Epoch 7 Batch 0 Loss 1.3121309280395508\n",
      "Epoch 7 Batch 100 Loss 1.2977547645568848\n",
      "Epoch 7 Loss 1.3026\n",
      "Time taken for 1 epoch 13.17615532875061 sec\n",
      "\n",
      "Epoch 8 Batch 0 Loss 1.2648319005966187\n",
      "Epoch 8 Batch 100 Loss 1.2559295892715454\n",
      "Epoch 8 Loss 1.2583\n",
      "Time taken for 1 epoch 13.247912645339966 sec\n",
      "\n",
      "Epoch 9 Batch 0 Loss 1.2237274646759033\n",
      "Epoch 9 Batch 100 Loss 1.2122169733047485\n",
      "Epoch 9 Loss 1.2092\n",
      "Time taken for 1 epoch 13.323989629745483 sec\n",
      "\n",
      "Epoch 10 Batch 0 Loss 1.1861408948898315\n",
      "Epoch 10 Batch 100 Loss 1.165980339050293\n",
      "Epoch 10 Loss 1.1561\n",
      "Time taken for 1 epoch 13.36545991897583 sec\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Training step\n",
    "EPOCHS = 10\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "  start = time.time()\n",
    "\n",
    "  # initializing the hidden state at the start of every epoch\n",
    "  # initally hidden is None\n",
    "  hidden = model.reset_states()\n",
    "\n",
    "  for (batch_n, (inp, target)) in enumerate(dataset):\n",
    "    loss = train_step(inp, target)\n",
    "\n",
    "    if batch_n % 100 == 0:\n",
    "      template = 'Epoch {} Batch {} Loss {}'\n",
    "      print(template.format(epoch+1, batch_n, loss))\n",
    "\n",
    "  # saving (checkpoint) the model every 5 epochs\n",
    "  if (epoch + 1) % 5 == 0:\n",
    "    model.save_weights(checkpoint_prefix.format(epoch=epoch))\n",
    "\n",
    "  print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n",
    "  print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))\n",
    "\n",
    "model.save_weights(checkpoint_prefix.format(epoch=epoch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "C_9USpdHlYtt"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "//learning/brain/python/client:tpu_hw_notebook",
    "kind": "private"
   },
   "name": "text_generation.ipynb",
   "provenance": [
    {
     "file_id": "https://github.com/tensorflow/docs/blob/master/site/en/r2/tutorials/text/text_generation.ipynb",
     "timestamp": 1564754299619
    }
   ],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
